
Implement pseudo Huber loss for Flux and SD3 #1808

Merged 3 commits into kohya-ss:flux-huber-loss on Dec 1, 2024

Conversation

@recris commented Nov 27, 2024

This change set implements pseudo Huber loss functionality that was missing from Flux.1 and SD3 model training.

This also introduces a new parameter, huber_scale, for controlling the base threshold of the Huber function. The previous logic had a fixed threshold that worked very poorly with Flux.1 models; this is something that needs to be tuned to the loss "profile" of each model (mean absolute latent error, variance, etc.). The existing huber_c could only be used to adjust the decay of the exponential schedule, so a new parameter was required.
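For context, a rough sketch of the pseudo-Huber shape involved (illustrative only; delta below is just a stand-in for the per-timestep threshold that huber_scale and huber_c control through the schedule, not the exact code in this change set):

    import torch

    def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, delta: float) -> torch.Tensor:
        # Quadratic (L2-like) for |error| << delta, linear (L1-like) for |error| >> delta.
        diff = pred - target
        return delta**2 * (torch.sqrt(1 + (diff / delta) ** 2) - 1)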

Refactoring notes:

  • Moved most of the loss estimation logic inside the conditional_loss function; huber_c is no longer passed through the rest of the code, which accounts for most of the file changes.
  • Added an explicit error message when the "snr" schedule is used with a model that does not currently support it.

Tested with:

  • SDXL Lora
  • Flux.1 Dev Lora
  • SD3.5 Medium Lora (only checked that it doesn't crash)

For use with Flux.1 I recommend using huber_scale = 1.8 (or a bit higher) and huber_c = 0.25. YMMV.

@araleza commented Nov 27, 2024

This is great, I was literally just considering asking for Huber loss to be implemented for kohya Flux somehow. It was one of the biggest steps forward for my SDXL training. I'll try to give it a go soon.

@kohya-ss (Owner) left a comment


Thank you, this is great. I think it makes sense to split get_timesteps_and_huber_c.

However, regarding the responsibilities of the methods, I personally don't want to pass args to conditional_loss. We may refactor it after merging to call get_timesteps, get_huber_threshold, and conditional_loss in order from each script. We appreciate your understanding.
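For illustration, the per-step call order might then look roughly like this in each script (get_timesteps matches the signature in this PR; get_huber_threshold and the conditional_loss argument order below are assumptions, not the final API):

    # Hypothetical shape of the loss computation after the proposed split (sketch only).
    timesteps = train_util.get_timesteps(min_timestep, max_timestep, b_size, accelerator.device)
    huber_c = train_util.get_huber_threshold(args, timesteps, noise_scheduler)  # assumed helper signature
    loss = train_util.conditional_loss(model_pred, target, args.loss_type, "none", huber_c)  # assumed argument order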

Review threads on fine_tune.py and library/train_util.py (outdated, resolved)
@recris (Author) commented Nov 28, 2024

> Thank you, this is great. I think it makes sense to split get_timesteps_and_huber_c.
>
> However, regarding the responsibilities of the methods, I personally don't want to pass args to conditional_loss. We may refactor it after merging to call get_timesteps, get_huber_threshold, and conditional_loss in order from each script. We appreciate your understanding.

I also do not like the current approach; there should be a "step context" object that collects information about the training step and gets passed to the various methods/strategies/utility functions, instead of having to pass around a handful of separate parameters everywhere. But that is a bigger refactoring than I'd like to make here.
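Purely as a hypothetical illustration of that idea (not part of this PR), the step context could be a small dataclass that helpers accept instead of a long parameter list:

    from dataclasses import dataclass
    import torch

    @dataclass
    class TrainStepContext:
        # Hypothetical bundle of per-step training state.
        args: object
        noise_scheduler: object
        timesteps: torch.Tensor
        model_pred: torch.Tensor
        target: torch.Tensor

    # Helpers would then take a single context argument, e.g. conditional_loss(ctx).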

Review thread on sd3_train.py (outdated, resolved)
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
     huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
+def get_timesteps(min_timestep, max_timestep, b_size, device):
     timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

A reviewer commented:

Would this work under the current implementation of SD3 in this repo? I haven't looked at whether it was updated for a while.
As far as I can recall, the scheduling was implemented in this repo in a different way and didn't have add_noise like other implementations from diffusers.
So essentially it was randomizing the index for the timestep, but then taking the timestep itself from the noise_scheduler and calculating the noise using sigmas.
Maybe this has changed; I haven't looked at the repo recently, but it's worth a double check.

@recris (Author) replied:

If that is an issue, then it is separate from the scope of this PR.

@kohya-ss (Owner) replied:

add_noise can be found here and is not significantly different from other repositories.

noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

Only SD3/3.5 (sd3_train.py or sd3_train_network.py) doesn't use get_noise_noisy_latents_and_timesteps().

@araleza commented Nov 30, 2024

I gave it a try... I'm not seeing good results so far. :( I went for:

--loss_type huber --huber_c 0.25 --huber_scale 1.8 --huber_schedule="exponential" 

Learning with Huber loss enabled seemed a lot slower for the same LR, and I didn't see any big image quality improvements either. I pushed up the LR, but that just got me the usual grid-shaped noise pattern. I was using a gradient accumulation size of 16.

Maybe it's dataset-specific, and it'll work better for someone else? My training dataset is around 200 images in size, mostly high quality, but with some poorer quality ones mixed in, so I hoped Huber would help with that.

@recris (Author) commented Nov 30, 2024

> I gave it a try... I'm not seeing good results so far. :( I went for:
>
> --loss_type huber --huber_c 0.25 --huber_scale 1.8 --huber_schedule="exponential"
>
> Learning with Huber loss enabled seemed a lot slower for the same LR, and I didn't see any big image quality improvements either. I pushed up the LR, but that just got me the usual grid-shaped noise pattern. I was using a gradient accumulation size of 16.
>
> Maybe it's dataset-specific, and it'll work better for someone else? My training dataset is around 200 images in size, mostly high quality, but with some poorer quality ones mixed in, so I hoped Huber would help with that.

I've had to increase the LR as well in my tests, typically by around 20-50% depending on the huber_scale used. This is expected, given that this function effectively dampens the loss on larger error values in the latent, and as a consequence the gradient magnitude is on average lower. Theoretically, a lower huber_scale means more dampening of the learning - we are effectively adjusting a tradeoff between L1 and L2 loss.
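To make the dampening concrete, here is a tiny illustration using the generic pseudo-Huber form (not the exact scaling in this PR): the gradient saturates around the threshold instead of growing with the error, unlike plain L2.

    import torch

    err = torch.linspace(0.0, 3.0, 7)
    delta = 0.25
    l2_grad = 2 * err                                   # gradient of squared error keeps growing with err
    ph_grad = err / torch.sqrt(1 + (err / delta) ** 2)  # gradient of delta**2 * (sqrt(1 + (err/delta)**2) - 1)
    # For err >> delta, ph_grad flattens out near delta, so the average gradient
    # magnitude is lower than with plain L2 - hence the need for a higher LR.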

I am not sure about the best huber_scale for a general recommendation; the values I suggested worked OK for me, but my tests were limited and the datasets were not representative.

If the learning is not progressing as fast as before I recommend increasing huber_scale.

@recris (Author) commented Nov 30, 2024

Also, from past experience with this type of loss function in other models, I don't expect this to improve image detail quality much; it usually helps more with getting consistency in the outputs.

@araleza commented Nov 30, 2024

I have some better news to report now. Without changing anything in the settings I gave above, I've let the training continue longer, and the image quality has been rising steadily for a while now; I'm getting some nice output. Maybe Huber loss just has different learning characteristics.

@araleza commented Nov 30, 2024

I'm seeing a little bit of the grid artifact creeping back in, though (while training at 5e-4, grad_accum 16). I don't know why it is that this DiT artifact appears just when my image quality gets to be very good. Maybe we need block-specific learning rates? I don't fully understand the problem, though.

This is the grid artifact I'm referring to (although this one seems mostly vertical): [image]

It seems to be particularly bad at the right hand edge of the images for some reason, although I can see it faintly across the whole image sometimes.

@recris (Author) commented Nov 30, 2024

I usually observe this kind of thing when pushing LoRA strength way above 1.0 during inference, or when applying multiple LoRAs simultaneously without lowering the strength of each. I suspect this is related to having very large weight magnitudes due to stacking, or to over-training.

Training at 5e-4 seems super high to me, I usually don't go above 2e-4.

What are your average_key_norm and max_key_norm values (in TensorBoard)? These give an indication of how much the trained model has "deviated" from the base model.

@araleza commented Nov 30, 2024

I usually have a lower LR, but once the gradient accumulation value goes up, I find I usually have to push the LR higher, as I think the gradient update is divided by the number of accumulations, unlike batch size.

And higher LRs actually work amazingly well with Flux. The images I've been getting have rich colors and interesting camera angles (i.e. the Flux DPO seems to still be active), but also include the trained objects. They also have rich backgrounds with many scene-appropriate items appearing, a sign of network health. The pictures I get out from flux_minimal_inference.py (which is what I use for sample images) are honestly the kind of thing I'd put up on my art page, if it wasn't for the grid artifact appearing on them.

Lower LRs haven't worked so well, as the trained objects don't get learned, and my images get boring colors and camera angles, showing the DPO has been lost. I get the feeling local minima are a real problem with image training, and the higher LRs can power through those.

I really should get tensorboard up and running with my runs. I do wonder if maybe what we need for Flux is the ability to set the LR individually for each of the 19 double and the 38 single blocks. Maybe the average_key_norm and max_key_norm sizes would guide me towards which ones.

Edit: I'm trying a 1e-4 run now, still with grad_accum=16. Maybe this LR will work well with Huber loss.

@recris (Author) commented Nov 30, 2024

I've been told that high gradient accumulation (with 16-bit floats) can cause numerical instability and/or loss of precision, not sure if that applies here.

From tests I've made, these types of visual artifacts usually tend to appear when average_key_norm starts to go above ~1.2 (in datasets with around 50 to 100 images). Usually I try to reduce the LR and see if it converges to optimal results in the same number of steps. average_key_norm tends to grow proportionally to the LR used, but this is affected by the size and diversity of the dataset; with larger ones it seems to grow more slowly. You could also try max norm regularization to penalize weights growing above a certain threshold, by setting scale_weight_norms to something like 4.0 or 5.0.
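Concretely, that is just one extra flag on the training command line, for example:

    --scale_weight_norms 4.0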

I suspect it would be even more effective to set distinct scale_weight_norms values at block level, but that is currently not implemented.

@kohya-ss changed the base branch from sd3 to flux-huber-loss on Dec 1, 2024
@kohya-ss merged commit a5a27fe into kohya-ss:flux-huber-loss on Dec 1, 2024 (1 check passed)
@araleza commented Dec 1, 2024

Okay, so I have some more good news: my training run at 1e-4 with @recris's Huber loss went very well. Now I have some very high quality images, with the concept in the training images learned close to perfectly. (It's still training.)

The big thing to note is how long it took. Training with gradient accumulation 16 at 1e-4 took pretty much a whole day of training to get it to work. But it did make continuous steady progress throughout this time. Without Huber loss, I either found the learning of the training concept would get stuck and sample images would continually mis-draw the objects being learned, or that the sample images would lose their high image quality and draw boring scenes. That doesn't seem to be the case with Huber loss.

@recris, I've heard things about high gradient accumulation being bad too, but it doesn't seem to have affected me here. Before Huber loss was implemented, I pushed up from accumulating 8 to accumulating 16 to try to reach better image quality, and that did seem to work. But yeah, gradient accumulation feels a bit suspect in general compared to increasing batch size. It does use less memory, though, and another advantage of gradient accumulation that people often miss is that batch size only groups images that fall in the same bucket for image dimensions. Often buckets contain only single images, silently preventing batch size from having any effect on them. Gradient accumulation is immune to that.

Anyway, I'm glad that this PR has made it into the SD3 branch, as it seems worthwhile to me.

@StableLlama commented
All these insights about how to use it are great - but having them only in this PR makes them hard to find in the future.
So please also put them in a documentation file so that they can be looked up anytime!

@FurkanGozukara commented
@bmaltais can we have this? I would like to test it, thank you.

@araleza commented Dec 11, 2024

> You could also try max norm regularization to penalize weights growing above a certain threshold, by setting scale_weight_norms to something like 4.0 or 5.0.

Hey there @recris, I gave this a try (with 4.0), and yeah that lets me get to the higher LRs without those vertical lines, with the target object being learned at higher quality. Great suggestion.

I hope we can get individual LRs for individual blocks at some point, like we had for SDXL with --block_lr=.... That would probably mean that whatever keys are growing very fast wouldn't get so big and need to be clipped in length. But in the meantime, renormalizing them seems to do the job.

@araleza commented Dec 13, 2024

Okay, I found something I think is pretty cool, and I thought I'd put it here cause it relates to what's been said so far.

I found that using --scale_weight_norms did get rid of the vertical lines on the image when I used higher LRs... but it also seemed to limit the amount of learning that occurred. The object I was trying to learn didn't contain all the fine details that I wanted.

I tried pushing up the value for --scale_weight_norms, but then my images became 'contrasty', with faces often being bright and washed out.

I suspected that the issue was that the later single blocks were getting overtrained, and they can't deal with the higher key lengths - but the earlier double blocks were not overtrained and needed the longer key lengths available to them to be able to learn the trained object.

So I added this code to lora_flux.py (at around line 1145):

            # Attempt a different max scaling for single blocks
            if downkeys[i].find('single') != -1:
                this_max_norm_value = 1
            else:
                this_max_norm_value = max_norm_value

and changed the next couple of lines to use the new this_max_norm_value instead of max_norm_value:

            norm = updown.norm().clamp(min=this_max_norm_value / 2)

            desired = torch.clamp(norm, max=this_max_norm_value)

This seemed to immediately get rid of the contrasty look, and let me increase my value for --scale_weight_norms, allowing the object I was learning to be learned far more accurately.

Maybe we need a --scale_weight_norms_single 1.0 option?

@recris (Author) commented Dec 13, 2024

Like you said before, we need a way to control the learning rate at the block level. I find it too easy to create a burned LoRA before it properly learns the target concept(s).

Alternatively, there is a chance this could be mitigated with a better loss function - maybe the work in #294 should be revisited and translated to Flux training.

@recris (Author) commented Dec 15, 2024

@araleza If you're training a style, you may wish to have a look at #1838; it could help with the learning issues.

@araleza commented Dec 16, 2024

Hey thanks for the pointer, @recris. I'm actually learning a complex object rather than a style, but maybe I'll want a style later.

I found another way to improve how well my object is learned that doesn't involve pushing the key size up further. I've discovered that instead of just setting my LoRA's alpha to 0.5x the rank value, much higher alphas can work very well. Now that I've got it around 1.67x the rank value, I'm getting great results.
